import numpy as np
import math
from scipy.spatial.transform import Rotation
from scipy.interpolate import splprep, splev

def generate_ellipse_trajectory(semi_major_axis, semi_minor_axis, num_points):
    """Generate a planar elliptical trajectory with smoothly varying orientations.

    Positions lie on x = a*cos(t), y = b*sin(t), z = 0 for ``num_points``
    values of t evenly spaced over [0, 2*pi]. Both ends are included, so the
    first and last positions coincide.

    Orientations come from a random walk in 'xyz' Euler angles:
    - the start is drawn uniformly from [-pi/8, pi/8] per axis;
    - each key frame adds a uniform step in [-2*pi/num_points, 2*pi/num_points];
    - there is roughly one key frame per 10 positions.
    The key frames are upsampled with ``smooth_quaternions(..., 10)``, which is
    assumed to return 10 samples per key frame, and trimmed to ``num_points``.

    Args:
        semi_major_axis: ellipse half-width along x.
        semi_minor_axis: ellipse half-width along y.
        num_points: number of poses to generate.

    Returns:
        (trans, quats): lists of ``num_points`` positions (length-3 arrays) and
        ``num_points`` quaternions (length-4 arrays from ``Rotation.as_quat``).
    """
    if num_points <= 0:
        return [], []

    ts = np.linspace(0, 2 * np.pi, num_points)
    trans = [
        np.array([semi_major_axis * np.cos(t), semi_minor_axis * np.sin(t), 0.0])
        for t in ts
    ]

    upsample = 10
    # Round up so that upsample * num_keys always covers num_points.
    # Previously int(num_points / 10) left the quaternion list short of the
    # position list whenever num_points was not a multiple of 10.
    # The cubic spline in smooth_quaternions also needs at least 4 key frames.
    num_keys = max(4, math.ceil(num_points / upsample))

    step = 2 * math.pi / num_points
    euler = np.random.uniform(-math.pi / 8, math.pi / 8, size=(3,))
    key_quats = []
    for _ in range(num_keys):
        euler += np.random.uniform(-step, step, size=(3,))
        key_quats.append(Rotation.from_euler('xyz', euler).as_quat())

    # Smooth/upsample the key frames, then keep exactly one quaternion per position.
    interp_quats = smooth_quaternions(key_quats, upsample)
    return trans, interp_quats[:num_points]

def smooth_quaternions(quat, times):
    """Upsample a sequence of key-frame orientations with an interpolating spline.

    Steps:
    1. Convert each quaternion to 'xyz' Euler angles.
    2. Unwrap the angles along the key-frame axis to remove +/-pi jumps.
    3. Fit an interpolating (s=0) parametric spline, with the key frames
       evenly spaced in the spline parameter.
    4. Sample the spline at ``len(quat) * times`` evenly spaced parameter
       values; the first and last samples coincide with the first and last
       key frames.
    5. Convert the samples back to quaternions.

    Args:
        quat: sequence of quaternions in the layout used by
            ``Rotation.from_quat``.
        times: upsampling factor.

    Returns:
        list of ``len(quat) * times`` quaternions (length-4 arrays from
        ``Rotation.as_quat``).
    """
    num_keys = len(quat)
    num_out = num_keys * times
    if num_keys == 0:
        return []

    euler = np.array([Rotation.from_quat(q).as_euler('xyz') for q in quat])

    if num_keys == 1:
        # A single key frame: nothing to interpolate, just repeat it.
        interp_euler = np.repeat(euler, num_out, axis=0)
    else:
        # Euler angles wrap at +/-pi. Unwrap along the key axis so the spline
        # does not sweep through a spurious full turn at a wrap-around.
        euler = np.unwrap(euler, axis=0)
        # Parameterize uniformly by key index because the key frames are
        # evenly spaced in time. splprep's default chord-length
        # parameterization would resample by Euler-angle arc length, and it
        # degenerates (division by zero) when consecutive key frames are
        # identical.
        key_u = np.linspace(0, 1, num_keys)
        degree = min(3, num_keys - 1)  # splprep requires more points than the degree
        tck, _ = splprep(euler.T, u=key_u, k=degree, s=0)
        interp_euler = np.array(splev(np.linspace(0, 1, num_out), tck)).T

    return [Rotation.from_euler('xyz', angles).as_quat() for angles in interp_euler]


def transform_trajectory(trans, quats, rotation, translation, s):
    """Apply a similarity transform (scale ``s``, rotation, translation) to a trajectory.

    Positions are mapped as ``p' = s * R @ p + translation``, where ``R`` is
    ``rotation.as_matrix()``.

    Each pose rotation matrix ``R1`` is right-multiplied by ``R^T``
    (``R' = R1 @ R^T``). This treats the stored quaternion as the inverse
    (transpose) of the pose's world-frame rotation.
    NOTE(review): confirm this convention matches how the output files are
    consumed.

    Args:
        trans: sequence of length-3 positions.
        quats: sequence of quaternions in ``Rotation.from_quat`` layout,
            indexed in step with ``trans``.
        rotation: ``scipy.spatial.transform.Rotation`` to apply.
        translation: length-3 offset added after rotation and scaling.
        s: scalar scale factor.

    Returns:
        (transformed_trans, rotated_quat): new lists of positions and quaternions.
    """
    rot_mat = rotation.as_matrix()
    rot_mat_inv = rot_mat.T

    out_trans = []
    out_quats = []
    for idx, position in enumerate(trans):
        pose_rot = Rotation.from_quat(quats[idx]).as_matrix()
        out_quats.append(Rotation.from_matrix(pose_rot @ rot_mat_inv).as_quat())
        out_trans.append(s * rot_mat @ position + translation)

    return out_trans, out_quats




# Add drift, scale s, matrix S and a fixed offset dT between the two body frames.
def Stransform_trajectory(trans, quats, rotation, translation, s, Is_add_drift=True):
    """Transform a trajectory with a similarity, a body-frame offset and optional drift.

    Pose convention: a pose is stored as [R, t] with Twc = [R^T, t]. The
    second trajectory is

        T2wc = T21w * T1wc * dT21c^-1

    with dT = [S, dR, dt] and T21w = [s, S, R, t]. Per pose this code computes

        R2 = S * dR * Rdrift * R1 * R^T * S
        t2 = -s*S*R*R1^T*dR^T*S*dt + s*S*R*t1 + t + tdrift

    Terms:
    - S is currently the 3x3 identity.
    - dR (random angle in [-pi/4, pi/4] about a random unit axis) and dt
      (uniform in [-1, 1]^3) are drawn once per call.
    - When ``Is_add_drift`` is true, Rdrift/tdrift are random walks.
      N(0, 0.001) noise is added per pose to the 'xyz' Euler drift angles
      and to the translation drift.

    The root-mean-square of the accumulated drift (per rotation axis, and for
    the translation norm) is printed at the end.

    Args:
        trans: sequence of length-3 positions.
        quats: sequence of quaternions in ``Rotation.from_quat`` layout,
            indexed in step with ``trans``.
        rotation: ``scipy.spatial.transform.Rotation`` R of the similarity.
        translation: length-3 translation t of the similarity.
        s: scalar scale factor.
        Is_add_drift: whether to accumulate random drift.

    Returns:
        (transformed_trans, rotated_quat): new lists of positions and quaternions.
    """
    S = np.identity(3, dtype=int)

    # Fixed random offset dT = [dR, dt] between the two body frames.
    dr_angle = np.random.uniform(-math.pi / 4, math.pi / 4)
    dt = np.random.uniform(-1, 1, size=(3,))
    dr_vector = np.random.uniform(-1, 1, size=(3,))
    dr_vector /= np.linalg.norm(dr_vector)
    dr = Rotation.from_rotvec(dr_angle * dr_vector).as_matrix()

    r = rotation.as_matrix()
    t = translation

    # The accumulators must be float arrays. With plain lists, `+=` extends
    # the list instead of adding element-wise, and `/=` raises TypeError.
    drift_r_euler = np.zeros(3)
    drift_t = np.zeros(3)
    sq_sum_r = np.zeros(3)
    sq_sum_t = 0.0

    rotated_quat = []
    transformed_trans = []
    for i in range(len(trans)):
        if Is_add_drift:
            drift_r_euler += np.random.normal(0, 0.001, 3)
            drift_t += np.random.normal(0, 0.001, 3)
            sq_sum_r += drift_r_euler * drift_r_euler
            sq_sum_t += drift_t @ drift_t

        drift_r = Rotation.from_euler('xyz', drift_r_euler).as_matrix()

        r1 = Rotation.from_quat(quats[i]).as_matrix()
        r2 = S @ dr @ drift_r @ r1 @ r.T @ S

        t1 = trans[i]
        t2 = -s * S @ r @ r1.T @ dr.T @ S @ dt + s * S @ r @ t1 + t + drift_t

        rotated_quat.append(Rotation.from_matrix(r2).as_quat())
        transformed_trans.append(t2)

    num_poses = len(trans)
    if num_poses:
        # RMSE = sqrt(mean(squared error)); the square root was previously missing.
        rmse_r = np.sqrt(sq_sum_r / num_poses)
        rmse_t = math.sqrt(sq_sum_t / num_poses)
    else:
        rmse_r = np.zeros(3)
        rmse_t = 0.0

    print(f"rmse: rx ry rz t: {rmse_r[0]:.6f} {rmse_r[1]:.6f} {rmse_r[2]:.6f} {rmse_t:.6f}")
    return transformed_trans, rotated_quat



def save_tum_trajectory(filename, trans, quats):
    """Write a trajectory to ``filename``, one pose per line.

    Each line is ``index tx ty tz q0 q1 q2 q3``. The pose index fills the
    first (timestamp) column. The quaternion components are written in their
    stored order, presumably (x, y, z, w) as produced by ``Rotation.as_quat``,
    which is TUM's ``qx qy qz qw`` column order.

    Any existing file is overwritten.
    """
    with open(filename, 'w') as out:
        for idx, position in enumerate(trans):
            q = quats[idx]
            fields = (idx, position[0], position[1], position[2], q[0], q[1], q[2], q[3])
            out.write(" ".join(f"{value}" for value in fields) + "\n")


if __name__ == "__main__":

    def random_similarity():
        """Draw a random similarity transform.

        Returns (rotation, translation, scale):
        - rotation: angle uniform in [-pi/4, pi/4] about a random unit axis;
        - translation: uniform in [-1, 1]^3;
        - scale: uniform in [0, 1).
        Random numbers are drawn in the order: angle, translation, axis, scale.
        """
        angle = np.random.uniform(-math.pi/4, math.pi/4)
        offset = np.random.uniform(-1, 1, size=(3,))
        axis = np.random.uniform(-1, 1, size=(3,))
        axis /= np.linalg.norm(axis)
        scale = np.random.uniform(0, 1)
        return Rotation.from_rotvec(angle * axis), offset, scale

    # Ellipse semi-major axis, semi-minor axis and number of poses.
    semi_major_axis, semi_minor_axis, num_points = 2.0, 1.0, 1000
    Is_add_drift = True

    # Output files:
    # - the source trajectory;
    # - a copy under a random similarity Tt = [sR | t] * Ts;
    # - a copy with drift, scale s, S and a body-frame offset dT.
    source_file = "ellipse_trajectory.txt"
    similarity_file = "ellipse_trajectory_transformed.txt"
    drifted_file = "ellipse_trajectory_Stransformed.txt"

    trans, quats = generate_ellipse_trajectory(semi_major_axis, semi_minor_axis, num_points)
    save_tum_trajectory(source_file, trans, quats)

    rotation, translation, s = random_similarity()
    save_tum_trajectory(similarity_file, *transform_trajectory(trans, quats, rotation, translation, s))

    rotation, translation, s = random_similarity()
    save_tum_trajectory(
        drifted_file,
        *Stransform_trajectory(trans, quats, rotation, translation, s, Is_add_drift),
    )


# semi_major_axis: ellipse semi-major axis; semi_minor_axis: ellipse semi-minor axis; num_points: number of trajectory points
